{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#skip\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp vision.models.xresnet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.torch_basics import *\n",
    "from torchvision.models.utils import load_state_dict_from_url"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# XResnet\n",
    "\n",
 Resnet">
    "> ResNet from the Bag of Tricks for Image Classification paper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def init_cnn(m):\n",
    "    if getattr(m, 'bias', None) is not None: nn.init.constant_(m.bias, 0)\n",
    "    if isinstance(m, (nn.Conv1d,nn.Conv2d,nn.Conv3d,nn.Linear)): nn.init.kaiming_normal_(m.weight)\n",
    "    for l in m.children(): init_cnn(l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class XResNet(nn.Sequential):\n",
    "    @delegates(ResBlock)\n",
    "    def __init__(self, block, expansion, layers, p=0.0, c_in=3, n_out=1000, stem_szs=(32,32,64),\n",
    "                 widen=1.0, sa=False, act_cls=defaults.activation, ndim=2, ks=3, stride=2, **kwargs):\n",
    "        store_attr('block,expansion,act_cls,ndim,ks')\n",
    "        if ks % 2 == 0: raise Exception('kernel size has to be odd!')\n",
    "        stem_szs = [c_in, *stem_szs]\n",
    "        stem = [ConvLayer(stem_szs[i], stem_szs[i+1], ks=ks, stride=stride if i==0 else 1, \n",
    "                          act_cls=act_cls, ndim=ndim)\n",
    "                for i in range(3)]\n",
    "\n",
    "        block_szs = [int(o*widen) for o in [64,128,256,512] +[256]*(len(layers)-4)]\n",
    "        block_szs = [64//expansion] + block_szs\n",
    "        blocks    = self._make_blocks(layers, block_szs, sa, stride, **kwargs)\n",
    "\n",
    "        super().__init__(\n",
    "            *stem, MaxPool(ks=ks, stride=stride, padding=ks//2, ndim=ndim),\n",
    "            *blocks,\n",
    "            AdaptiveAvgPool(sz=1, ndim=ndim), Flatten(), nn.Dropout(p),\n",
    "            nn.Linear(block_szs[-1]*expansion, n_out),\n",
    "        )\n",
    "        init_cnn(self)\n",
    "\n",
    "    def _make_blocks(self, layers, block_szs, sa, stride, **kwargs):\n",
    "        return [self._make_layer(ni=block_szs[i], nf=block_szs[i+1], blocks=l,\n",
    "                                 stride=1 if i==0 else stride, sa=sa and i==len(layers)-4, **kwargs)\n",
    "                for i,l in enumerate(layers)]\n",
    "\n",
    "    def _make_layer(self, ni, nf, blocks, stride, sa, **kwargs):\n",
    "        return nn.Sequential(\n",
    "            *[self.block(self.expansion, ni if i==0 else nf, nf, stride=stride if i==0 else 1,\n",
    "                      sa=sa and i==(blocks-1), act_cls=self.act_cls, ndim=self.ndim, ks=self.ks, **kwargs)\n",
    "              for i in range(blocks)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _xresnet(pretrained, expansion, layers, **kwargs):\n",
    "    # TODO pretrain all sizes. Currently will fail with non-xrn50\n",
    "    url = 'https://s3.amazonaws.com/fast-ai-modelzoo/xrn50_940.pth'\n",
    "    res = XResNet(ResBlock, expansion, layers, **kwargs)\n",
    "    if pretrained: res.load_state_dict(load_state_dict_from_url(url, map_location='cpu')['model'], strict=False)\n",
    "    return res\n",
    "\n",
    "def xresnet18 (pretrained=False, **kwargs): return _xresnet(pretrained, 1, [2, 2,  2, 2], **kwargs)\n",
    "def xresnet34 (pretrained=False, **kwargs): return _xresnet(pretrained, 1, [3, 4,  6, 3], **kwargs)\n",
    "def xresnet50 (pretrained=False, **kwargs): return _xresnet(pretrained, 4, [3, 4,  6, 3], **kwargs)\n",
    "def xresnet101(pretrained=False, **kwargs): return _xresnet(pretrained, 4, [3, 4, 23, 3], **kwargs)\n",
    "def xresnet152(pretrained=False, **kwargs): return _xresnet(pretrained, 4, [3, 8, 36, 3], **kwargs)\n",
    "def xresnet18_deep  (pretrained=False, **kwargs): return _xresnet(pretrained, 1, [2,2,2,2,1,1], **kwargs)\n",
    "def xresnet34_deep  (pretrained=False, **kwargs): return _xresnet(pretrained, 1, [3,4,6,3,1,1], **kwargs)\n",
    "def xresnet50_deep  (pretrained=False, **kwargs): return _xresnet(pretrained, 4, [3,4,6,3,1,1], **kwargs)\n",
    "def xresnet18_deeper(pretrained=False, **kwargs): return _xresnet(pretrained, 1, [2,2,1,1,1,1,1,1], **kwargs)\n",
    "def xresnet34_deeper(pretrained=False, **kwargs): return _xresnet(pretrained, 1, [3,4,6,3,1,1,1,1], **kwargs)\n",
    "def xresnet50_deeper(pretrained=False, **kwargs): return _xresnet(pretrained, 4, [3,4,6,3,1,1,1,1], **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "se_kwargs1 = dict(groups=1 , reduction=16)\n",
    "se_kwargs2 = dict(groups=32, reduction=16)\n",
    "se_kwargs3 = dict(groups=32, reduction=0)\n",
    "g0 = [2,2,2,2]\n",
    "g1 = [3,4,6,3]\n",
    "g2 = [3,4,23,3]\n",
    "g3 = [3,8,36,3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def xse_resnet18(n_out=1000, pretrained=False, **kwargs):   return XResNet(SEBlock,  1, g0, n_out=n_out, **se_kwargs1, **kwargs)\n",
    "def xse_resnext18(n_out=1000, pretrained=False, **kwargs):  return XResNet(SEResNeXtBlock, 1, g0, n_out=n_out, **se_kwargs2, **kwargs)\n",
    "def xresnext18(n_out=1000, pretrained=False, **kwargs):     return XResNet(SEResNeXtBlock, 1, g0, n_out=n_out, **se_kwargs3, **kwargs)\n",
    "def xse_resnet34(n_out=1000, pretrained=False, **kwargs):   return XResNet(SEBlock,  1, g1, n_out=n_out, **se_kwargs1, **kwargs)\n",
    "def xse_resnext34(n_out=1000, pretrained=False, **kwargs):  return XResNet(SEResNeXtBlock, 1, g1, n_out=n_out, **se_kwargs2, **kwargs)\n",
    "def xresnext34(n_out=1000, pretrained=False, **kwargs):     return XResNet(SEResNeXtBlock, 1, g1, n_out=n_out, **se_kwargs3, **kwargs)\n",
    "def xse_resnet50(n_out=1000, pretrained=False, **kwargs):   return XResNet(SEBlock,  4, g1, n_out=n_out, **se_kwargs1, **kwargs)\n",
    "def xse_resnext50(n_out=1000, pretrained=False, **kwargs):  return XResNet(SEResNeXtBlock, 4, g1, n_out=n_out, **se_kwargs2, **kwargs)\n",
    "def xresnext50(n_out=1000, pretrained=False, **kwargs):     return XResNet(SEResNeXtBlock, 4, g1, n_out=n_out, **se_kwargs3, **kwargs)\n",
    "def xse_resnet101(n_out=1000, pretrained=False, **kwargs):  return XResNet(SEBlock,  4, g2, n_out=n_out, **se_kwargs1, **kwargs)\n",
    "def xse_resnext101(n_out=1000, pretrained=False, **kwargs): return XResNet(SEResNeXtBlock, 4, g2, n_out=n_out, **se_kwargs2, **kwargs)\n",
    "def xresnext101(n_out=1000, pretrained=False, **kwargs):    return XResNet(SEResNeXtBlock, 4, g2, n_out=n_out, **se_kwargs3, **kwargs)\n",
    "def xse_resnet152(n_out=1000, pretrained=False, **kwargs):  return XResNet(SEBlock,  4, g3, n_out=n_out, **se_kwargs1, **kwargs)\n",
    "def xsenet154(n_out=1000, pretrained=False, **kwargs):\n",
    "    return XResNet(SEBlock, g3, groups=64, reduction=16, p=0.2, n_out=n_out)\n",
    "def xse_resnext18_deep  (n_out=1000, pretrained=False, **kwargs):  return XResNet(SEResNeXtBlock, 1, g0+[1,1], n_out=n_out, **se_kwargs2, **kwargs)\n",
    "def xse_resnext34_deep  (n_out=1000, pretrained=False, **kwargs):  return XResNet(SEResNeXtBlock, 1, g1+[1,1], n_out=n_out, **se_kwargs2, **kwargs)\n",
    "def xse_resnext50_deep  (n_out=1000, pretrained=False, **kwargs):  return XResNet(SEResNeXtBlock, 4, g1+[1,1], n_out=n_out, **se_kwargs2, **kwargs)\n",
    "def xse_resnext18_deeper(n_out=1000, pretrained=False, **kwargs):  return XResNet(SEResNeXtBlock, 1, [2,2,1,1,1,1,1,1], n_out=n_out, **se_kwargs2, **kwargs)\n",
    "def xse_resnext34_deeper(n_out=1000, pretrained=False, **kwargs):  return XResNet(SEResNeXtBlock, 1, [3,4,4,2,2,1,1,1], n_out=n_out, **se_kwargs2, **kwargs)\n",
    "def xse_resnext50_deeper(n_out=1000, pretrained=False, **kwargs):  return XResNet(SEResNeXtBlock, 4, [3,4,4,2,2,1,1,1], n_out=n_out, **se_kwargs2, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = xse_resnext18()\n",
    "x = torch.randn(64, 3, 128, 128)\n",
    "y = tst(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = xresnext18()\n",
    "x = torch.randn(64, 3, 128, 128)\n",
    "y = tst(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = xse_resnet50()\n",
    "x = torch.randn(8, 3, 64, 64)\n",
    "y = tst(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = xresnet18(ndim=1, c_in=1, ks=15)\n",
    "x = torch.randn(64, 1, 128)\n",
    "y = tst(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = xresnext50(ndim=1, c_in=2, ks=31, stride=4)\n",
    "x = torch.randn(8, 2, 128)\n",
    "y = tst(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = xresnet18(ndim=3, c_in=3, ks=3)\n",
    "x = torch.randn(8, 3, 32, 32, 32)\n",
    "y = tst(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_torch_core.ipynb.\n",
      "Converted 01_layers.ipynb.\n",
      "Converted 02_data.load.ipynb.\n",
      "Converted 03_data.core.ipynb.\n",
      "Converted 04_data.external.ipynb.\n",
      "Converted 05_data.transforms.ipynb.\n",
      "Converted 06_data.block.ipynb.\n",
      "Converted 07_vision.core.ipynb.\n",
      "Converted 08_vision.data.ipynb.\n",
      "Converted 09_vision.augment.ipynb.\n",
      "Converted 09b_vision.utils.ipynb.\n",
      "Converted 09c_vision.widgets.ipynb.\n",
      "Converted 10_tutorial.pets.ipynb.\n",
      "Converted 11_vision.models.xresnet.ipynb.\n",
      "Converted 12_optimizer.ipynb.\n",
      "Converted 13_callback.core.ipynb.\n",
      "Converted 13a_learner.ipynb.\n",
      "Converted 13b_metrics.ipynb.\n",
      "Converted 14_callback.schedule.ipynb.\n",
      "Converted 14a_callback.data.ipynb.\n",
      "Converted 15_callback.hook.ipynb.\n",
      "Converted 15a_vision.models.unet.ipynb.\n",
      "Converted 16_callback.progress.ipynb.\n",
      "Converted 17_callback.tracker.ipynb.\n",
      "Converted 18_callback.fp16.ipynb.\n",
      "Converted 18a_callback.training.ipynb.\n",
      "Converted 19_callback.mixup.ipynb.\n",
      "Converted 20_interpret.ipynb.\n",
      "Converted 20a_distributed.ipynb.\n",
      "Converted 21_vision.learner.ipynb.\n",
      "Converted 22_tutorial.imagenette.ipynb.\n",
      "Converted 23_tutorial.vision.ipynb.\n",
      "Converted 24_tutorial.siamese.ipynb.\n",
      "Converted 24_vision.gan.ipynb.\n",
      "Converted 30_text.core.ipynb.\n",
      "Converted 31_text.data.ipynb.\n",
      "Converted 32_text.models.awdlstm.ipynb.\n",
      "Converted 33_text.models.core.ipynb.\n",
      "Converted 34_callback.rnn.ipynb.\n",
      "Converted 35_tutorial.wikitext.ipynb.\n",
      "Converted 36_text.models.qrnn.ipynb.\n",
      "Converted 37_text.learner.ipynb.\n",
      "Converted 38_tutorial.text.ipynb.\n",
      "Converted 39_tutorial.transformers.ipynb.\n",
      "Converted 40_tabular.core.ipynb.\n",
      "Converted 41_tabular.data.ipynb.\n",
      "Converted 42_tabular.model.ipynb.\n",
      "Converted 43_tabular.learner.ipynb.\n",
      "Converted 44_tutorial.tabular.ipynb.\n",
      "Converted 45_collab.ipynb.\n",
      "Converted 46_tutorial.collab.ipynb.\n",
      "Converted 50_tutorial.datablock.ipynb.\n",
      "Converted 60_medical.imaging.ipynb.\n",
      "Converted 61_tutorial.medical_imaging.ipynb.\n",
      "Converted 65_medical.text.ipynb.\n",
      "Converted 70_callback.wandb.ipynb.\n",
      "Converted 71_callback.tensorboard.ipynb.\n",
      "Converted 72_callback.neptune.ipynb.\n",
      "Converted 73_callback.captum.ipynb.\n",
      "Converted 74_callback.cutmix.ipynb.\n",
      "Converted 97_test_utils.ipynb.\n",
      "Converted 99_pytorch_doc.ipynb.\n",
      "Converted dev-setup.ipynb.\n",
      "Converted index.ipynb.\n",
      "Converted quick_start.ipynb.\n",
      "Converted tutorial.ipynb.\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "from nbdev.export import *\n",
    "notebook2script()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "fastaidev",
   "language": "python",
   "name": "fastaidev"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
